Introduce PRNG to SimState and add reproducibility docs. #460

Open
CompRhys wants to merge 10 commits into main from prng-simstate-reproduce

Conversation

@CompRhys
Member

@CompRhys CompRhys commented Feb 21, 2026

The only messy bit is resumption: for serialization, the only option seems to be torch.save, and asking the user to store the pickle manually feels awkward.


AI Overview

Every SimState now carries an optional _rng field (a torch.Generator) that controls all stochastic operations: momentum initialization, Langevin OU noise, V-Rescale Gamma draws, and C-Rescale barostat noise. No integrator init or step function accepts a seed or prng argument anymore — seeding is done exclusively through the state.

The rng property

state.rng = 42          # int → coerced to a seeded Generator
state.rng = gen         # torch.Generator used directly
state.rng = None        # reset; next access creates an unseeded Generator
rng = state.rng         # lazily initialises if _rng is None/int, then returns the Generator
  • Lazy: if _rng is None (the default), accessing state.rng creates a new torch.Generator on the state's device and stores it back. No Generator is allocated until first use.
  • Coercing: if _rng is an int, accessing state.rng converts it to a seeded Generator via coerce_prng() and stores it back, so subsequent accesses return the same (advancing) Generator.
  • Advancing: because a single torch.Generator object is stored, its internal state advances with each draw, giving a proper random stream rather than re-seeding every step.
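
The lazy, coercing, and advancing behaviours can be sketched with a toy class. This is only an illustration: `TinyState` and `_coerce` are hypothetical stand-ins, not the PR's `SimState` or `coerce_prng()`, and seeding the `None` case with `Generator.seed()` is an assumption. Only the `torch.Generator` calls are real PyTorch API.

```python
import torch


def _coerce(value, device="cpu"):
    """Hypothetical stand-in for coerce_prng(): None/int/Generator -> Generator."""
    if isinstance(value, torch.Generator):
        return value
    gen = torch.Generator(device=device)
    if value is None:
        gen.seed()  # assumption: draw a fresh non-deterministic seed
    else:
        gen.manual_seed(int(value))
    return gen


class TinyState:
    """Toy state with an optional RNG, mimicking the lazy `rng` property."""

    def __init__(self, device="cpu"):
        self.device = device
        self._rng = None  # nothing is allocated until first access

    @property
    def rng(self):
        # Lazy + coercing: replace None/int with a Generator and store it back.
        if not isinstance(self._rng, torch.Generator):
            self._rng = _coerce(self._rng, self.device)
        return self._rng

    @rng.setter
    def rng(self, value):
        self._rng = value  # coercion is deferred to the next read


state = TinyState()
state.rng = 42
first = torch.randn(3, generator=state.rng)
second = torch.randn(3, generator=state.rng)  # same object, stream has advanced

state.rng = 42  # re-seeding restarts the stream
replay = torch.randn(3, generator=state.rng)
```

Here `first` and `replay` hold the same values because both come from a Generator freshly seeded with 42, while `second` differs because the stored Generator advanced between draws.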

Cloning

state.clone() deep-copies the Generator via get_state() / set_state(), producing an independent copy with identical initial RNG state. Drawing from one does not affect the other.
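
A minimal sketch of the get_state() / set_state() copy described above. `clone_generator` is a hypothetical helper name used for illustration and is not necessarily how `state.clone()` is written in the PR.

```python
import torch


def clone_generator(gen: torch.Generator) -> torch.Generator:
    """Return an independent Generator with the same internal state."""
    new_gen = torch.Generator(device=gen.device)
    new_gen.set_state(gen.get_state())
    return new_gen


original = torch.Generator().manual_seed(0)
torch.randn(2, generator=original)  # advance the stream before cloning

copy = clone_generator(original)
a = torch.randn(4, generator=original)
b = torch.randn(4, generator=copy)      # same values: identical starting state
c = torch.randn(4, generator=original)  # the original advances ...
d = torch.randn(4, generator=copy)      # ... and the copy advances separately
```

Because the two Generators are distinct objects, each keeps its own position in the stream; they only agree as long as both are asked for the same sequence of draws.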

Splitting

state.split() copies global attributes (including _rng) to every piece. Each resulting single-system state receives its own copy of the Generator with the same RNG state; the pieces do not share one Generator object.

Concatenating

concatenate_states([s1, s2, ...]) takes global attributes from the first state. The resulting batch uses s1's Generator; other states' Generators are discarded.

Device movement

state.to(device) moves the Generator to the target device via coerce_prng(), which creates a new Generator on the target device and copies the RNG state if devices differ.

Serialisation

torch.save(state.rng.get_state(), "rng.pt")           # save
gen = torch.Generator(device=state.device)
gen.set_state(torch.load("rng.pt"))
state.rng = gen                                        # restore
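
The save/restore snippet can be checked end to end with a round trip through a temporary file. This sketch uses a bare `torch.Generator` rather than a `SimState`, and the file name and draw sizes are arbitrary choices for illustration.

```python
import os
import tempfile

import torch

gen = torch.Generator().manual_seed(123)
torch.randn(5, generator=gen)  # simulate some steps already taken

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "rng.pt")
    torch.save(gen.get_state(), path)         # checkpoint the RNG state
    expected = torch.randn(3, generator=gen)  # what the uninterrupted run draws next

    resumed = torch.Generator(device=gen.device)
    resumed.set_state(torch.load(path))       # restore into a fresh Generator
    actual = torch.randn(3, generator=resumed)
```

If the restore worked, `actual` matches `expected`, meaning the resumed run continues the same random stream as the uninterrupted one.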

What changed

  • _rng moved from MDState to SimState (it's a global attribute, not MD-specific).
  • All seed= / prng= parameters removed from integrator init functions.
  • initialize_momenta takes generator: torch.Generator | None directly.
  • V-Rescale Gamma sampling switched from torch.distributions.Gamma (unseeded) to torch._standard_gamma(..., generator=rng) so it's now fully seedable.
  • _rattle_sim_state in testing.py refactored to use state.rng instead of saving/restoring global RNG state.
  • coerce_prng handles cross-device Generator transfer.
  • _state_to_device handles Generator device movement.
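
The V-Rescale change can be illustrated as follows. This is a sketch under assumptions: `sample_chi2_like` is a hypothetical helper name, the `dof` values are made up, and `torch._standard_gamma` is a private PyTorch function whose behaviour may change between releases.

```python
import torch


def sample_chi2_like(dof: torch.Tensor, generator: torch.Generator) -> torch.Tensor:
    """Draw Gamma(shape=(dof - 1) / 2, rate=1/2) with an explicit Generator."""
    shape = (dof - 1) / 2
    # torch._standard_gamma samples Gamma(shape, rate=1); dividing by the
    # rate 1/2 (multiplying by 2) rescales the draw to Gamma(shape, rate=1/2).
    return 2.0 * torch._standard_gamma(shape, generator=generator)


dof = torch.tensor([10.0, 25.0, 100.0])  # illustrative degrees of freedom per system

r2 = sample_chi2_like(dof, torch.Generator().manual_seed(7))
r2_again = sample_chi2_like(dof, torch.Generator().manual_seed(7))
```

With the same seed the two calls return identical draws, which is the property that `torch.distributions.Gamma(...).sample()` cannot offer through a per-state Generator.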

@CompRhys added the api (API design discussions), breaking (Breaking changes), and keep-open (PRs to be ignored by StaleBot) labels Feb 21, 2026
@CompRhys CompRhys requested a review from thomasloux February 21, 2026 20:05
c2 = torch.sqrt(kT * (1 - torch.square(c1))).unsqueeze(-1)

# Generate random noise from normal distribution
noise = torch.randn_like(state.momenta, device=state.device, dtype=state.dtype)
Member Author

After pytorch/pytorch#165865, randn_like with a generator argument is in torch 2.10, but I am not sure we want to pin to 2.10, given that not all the models people want to use will support it.
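
For torch versions where randn_like has no generator argument, one possible workaround is to call torch.randn with the shape, device, and dtype copied by hand. This is a sketch, not necessarily what the PR does; `momenta` here is a stand-in tensor for `state.momenta`.

```python
import torch

momenta = torch.zeros(4, 3, dtype=torch.float64)  # stand-in for state.momenta

gen = torch.Generator(device=momenta.device).manual_seed(0)
# torch.randn accepts `generator`; shape, device and dtype are passed explicitly.
noise = torch.randn(
    momenta.shape, generator=gen, device=momenta.device, dtype=momenta.dtype
)

# Re-seeding reproduces the same noise.
gen.manual_seed(0)
noise_again = torch.randn(
    momenta.shape, generator=gen, device=momenta.device, dtype=momenta.dtype
)
```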

@CompRhys CompRhys marked this pull request as ready for review February 21, 2026 20:35
weibull = torch.distributions.weibull.Weibull(scale=0.1, concentration=1)
rnd = torch.randn_like(sim_state.positions)
rnd = rnd / torch.norm(rnd, dim=-1, keepdim=True)
shifts = weibull.sample(rnd.shape).to(device=sim_state.positions.device) * rnd
Member Author

Avoid weibull.sample(): it draws from the global RNG and cannot take a Generator.
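
One seedable alternative is inverse-CDF sampling driven by torch.rand, which does accept a Generator. The sketch below mirrors the snippet above with stand-in data; `sample_weibull` is a hypothetical helper and is not claimed to be the replacement used in the PR.

```python
import torch


def sample_weibull(shape, scale, concentration, generator, device=None, dtype=None):
    """Inverse-CDF Weibull sampler driven by an explicit Generator."""
    u = torch.rand(shape, generator=generator, device=device, dtype=dtype)
    # F(x) = 1 - exp(-(x / scale) ** k)  =>  x = scale * (-log(1 - u)) ** (1 / k)
    return scale * (-torch.log1p(-u)) ** (1.0 / concentration)


gen = torch.Generator().manual_seed(0)
positions = torch.zeros(8, 3)  # stand-in for sim_state.positions

# Random unit directions, each component then scaled by a Weibull draw,
# following the structure of the original snippet.
rnd = torch.randn(positions.shape, generator=gen)
rnd = rnd / torch.norm(rnd, dim=-1, keepdim=True)
shifts = sample_weibull(positions.shape, 0.1, 1.0, gen) * rnd
```

With concentration 1 the Weibull distribution reduces to an exponential with mean equal to the scale, so the formula simplifies to `-scale * log(1 - u)` in that case.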

# Generate random numbers
r1 = torch.randn(n_systems, device=device, dtype=dtype)
# Sample Gamma((dof - 1)/2, 1/2) = \sum_2^{dof} X_i^2 where X_i ~ N(0,1)
r2 = torch.distributions.Gamma((dof - 1) / 2, torch.ones_like(dof) / 2).sample()
Member Author

Avoid Gamma.sample(): it draws from the global RNG and cannot take a Generator.


@staticmethod
def _clone_attr(value: object) -> object:
    """Clone a single attribute value, handling torch.Generator specially."""
Member Author

forks have identical rng states.



- def calculate_momenta(
+ def initialize_momenta(
Member Author

driveby: this was a misleading name.

@abhijeetgangan
Collaborator

LGTM. Note that setting deterministic mode can have performance penalties. Also, the choice of text vs. binary file formats for restart files can affect reproducibility.

Labels

api (API design discussions), breaking (Breaking changes), keep-open (PRs to be ignored by StaleBot)

Development

Successfully merging this pull request may close these issues.

  • Add a page in docs about reproducibility
  • Add a seed for integrator step function to reproduce results
  • Allow seeds to be set for individual batches

2 participants